import numpy as np
from numpy import linalg as la
import random
import math


def SGD(cost, grad, nexp, hess, K_record_times, gamma, x0, batch_size, n):
    """Run mini-batch SGD with the step size ``gamma / sqrt(k + 1)``.

    Parameters
    ----------
    cost : callable
        ``cost(x, indices)``; evaluated on ``range(n)`` at every record time.
    grad : callable
        ``grad(x, indices)``; evaluated on one sampled mini-batch per step.
    nexp : int
        Number of independent runs; each run draws its own mini-batches.
    hess : object
        Unused; kept so the signature stays unchanged.
    K_record_times : sequence of int
        Iteration indices at which the cost is recorded. The last entry is
        used as the final iteration index ``K``. Entries are matched one
        after the other while ``k`` increases, so they are assumed to be
        strictly increasing.
    gamma : float
        Numerator of the step size.
    x0 : array_like
        Starting point of every run; it is never modified.
    batch_size : int
        Number of indices drawn (without replacement) for each mini-batch.
    n : int
        Mini-batch indices are sampled from ``range(n)``.

    Returns
    -------
    name : str
        Label of the form ``'SGD, step=<gamma>'``.
    f : ndarray, shape (len(K_record_times), nexp)
        Cost over ``range(n)`` at each record time, one column per run.
    gammas : ndarray, shape (K + 2,)
        ``gammas[k] = gamma / sqrt(k + 1)`` for ``k <= K``; the trailing
        slot ``K + 1`` is never used as a step and stays 0.
    """
    K = K_record_times[-1]

    f = np.zeros((len(K_record_times), nexp))

    # The schedule does not depend on the run, so build it once.  This also
    # makes the returned array well defined when nexp == 0.
    gammas = np.zeros((K + 2,))
    gammas[:K + 1] = gamma / np.sqrt(np.arange(1, K + 2))

    for e in range(nexp):
        # One mini-batch per iteration, drawn up front for this run only.
        batches = [random.sample(range(n), batch_size) for _ in range(K + 1)]

        i_record = 0
        # Only the current iterate is needed; rebinding (instead of an
        # in-place update) guarantees that x0 is left untouched.
        x = x0

        for k in range(K + 1):
            if k == K_record_times[i_record]:
                f[i_record, e] = cost(x, range(n))
                i_record += 1

            x = x - gammas[k] * grad(x, batches[k])

    name = 'SGD, step=' + "{:.2f}".format(gamma)
    return name, f, gammas



def SGD_decr(cost,grad,hess,nexp, K_record_times, compute_hess, gamma_init, decr, th, x0, batch_size, n):
    """Run mini-batch SGD with a constant-then-decreasing step size.

    The step is ``gamma_init`` while ``k < th``.  From ``k = th`` on it is
    ``gamma_init / sqrt(k - th + 1)`` when ``decr == 'sqrt'`` and
    ``gamma_init / (k - th + 1)`` for any other value of ``decr``.

    Parameters
    ----------
    cost : callable
        ``cost(x, indices)``; evaluated on ``range(n)`` at every record time.
    grad : callable
        ``grad(x, indices)``; evaluated on one sampled mini-batch per step.
    hess, compute_hess : object
        Unused; kept so the signature stays unchanged.
    nexp : int
        Number of independent runs; each run draws its own mini-batches.
    K_record_times : sequence of int
        Iteration indices at which results are recorded.  The last entry is
        used as the final iteration index ``K``.  Entries are matched one
        after the other while ``k`` increases, so they are assumed to be
        strictly increasing.
    gamma_init : float
        Initial (and largest) step size.
    decr : str
        ``'sqrt'`` selects the square-root decay; anything else selects the
        ``1/k``-type decay.
    th : int
        Iteration at which the decay starts.
    x0 : array_like
        Starting point of every run; it is never modified.  ``len(x0)``
        fixes the last dimension of ``x_rec``.
    batch_size : int
        Number of indices drawn (without replacement) for each mini-batch.
    n : int
        Mini-batch indices are sampled from ``range(n)``.

    Returns
    -------
    name : str
        LaTeX-formatted label for the method.
    f : ndarray, shape (len(K_record_times), nexp)
        Cost over ``range(n)`` at each record time.
    gammas_rec : ndarray, shape (len(K_record_times), nexp)
        Step size used at each record time.
    x_rec : ndarray, shape (len(K_record_times), nexp, len(x0))
        Iterate at each record time.
    """
    # number of iterations
    K = K_record_times[-1]

    # init
    n_rec = len(K_record_times)
    f = np.zeros((n_rec, nexp))
    gammas_rec = np.zeros((n_rec, nexp))
    x_rec = np.zeros((n_rec, nexp, len(x0)))

    # The step sizes depend only on k, so compute them once for all runs.
    gammas = np.zeros((K + 1,))
    for k in range(K + 1):
        if k < th:
            gammas[k] = gamma_init
        elif decr == 'sqrt':
            gammas[k] = gamma_init / math.sqrt(k - th + 1)
        else:
            gammas[k] = gamma_init / (k - th + 1)

    ## optimization
    for e in range(nexp):
        # One mini-batch per iteration, drawn up front for this run.
        batches = [random.sample(range(n), batch_size) for _ in range(K + 1)]

        i_record = 0
        # Only the current iterate is kept; rebinding (instead of updating
        # in place) guarantees that x0 is left untouched.
        x = x0
        for k in range(K + 1):
            # record
            if k == K_record_times[i_record]:
                gammas_rec[i_record, e] = gammas[k]
                f[i_record, e] = cost(x, range(n))
                x_rec[i_record, e, :] = x
                i_record += 1

            # update
            x = x - gammas[k] * grad(x, batches[k])

    ## name
    # Raw strings throughout: '\s' is not a valid escape sequence in a
    # normal string literal.
    if decr == 'sqrt':
        name = r'SGD, $\gamma_k=' + "{:.2f}".format(gamma_init) + r'/\sqrt{k+1}$'
    else:
        name = r'SGD, $\gamma_k=' + "{:.2f}".format(gamma_init) + r'/k$'

    return name, f, gammas_rec, x_rec


def SPS_max(cost, grad, hess, nexp, K_record_times, compute_hess, c, gamma0, x0, batch_size, n, alg):
    """Run SGD with a capped stochastic-Polyak-type step size.

    At iteration ``k`` the step is ``min(sps / c, cap_k)`` where
    ``sps = cost(x, batch) / ||grad(x, batch_g)||**2`` and the cap depends
    on ``alg``:

    * ``'max'``: ``cap_k = gamma0``
    * ``'bound'``: ``cap_k = gamma0 / (k + 1)``
    * ``'sqrt'``, ``'bound_2grad'``, ``'bound_3grad'``:
      ``cap_k = gamma0 / sqrt(k + 1)``

    For ``'bound_2grad'`` the update direction uses a second, independently
    sampled mini-batch.  For ``'bound_3grad'`` the gradient inside ``sps``
    additionally comes from a third independently sampled mini-batch
    (``batch_g``), while the cost in ``sps`` still uses the first one.

    Parameters
    ----------
    cost : callable
        ``cost(x, indices)``; evaluated on the mini-batch for the step size
        and on ``range(n)`` at every record time.
    grad : callable
        ``grad(x, indices)``.  It is assumed to be a deterministic function
        of its arguments: when the step size and the update use the same
        mini-batch the gradient is evaluated once and reused.
    hess : callable
        ``hess(x)`` returning two values; only called when ``compute_hess``
        is true.
    nexp : int
        Number of independent runs; each run draws its own mini-batches.
    K_record_times : sequence of int
        Iteration indices at which results are recorded.  The last entry is
        used as the final iteration index ``K``.  Entries are matched one
        after the other while ``k`` increases, so they are assumed to be
        strictly increasing.
    compute_hess : bool
        Whether to call ``hess`` at record times.
    c : float
        Divisor applied to ``sps``.
    gamma0 : float
        Numerator of the cap on the step size.
    x0 : array_like
        Starting point of every run; it is never modified.  ``len(x0)``
        fixes the last dimension of ``x_rec``.
    batch_size : int
        Number of indices drawn (without replacement) for each mini-batch.
    n : int
        Mini-batch indices are sampled from ``range(n)``.
    alg : str
        One of ``'max'``, ``'bound'``, ``'sqrt'``, ``'bound_2grad'``,
        ``'bound_3grad'``.

    Returns
    -------
    name : str
        LaTeX-formatted label for the selected variant.
    f, gammas_rec : ndarray, shape (len(K_record_times), nexp)
        Cost over ``range(n)`` and step size at each record time.
    x_rec : ndarray, shape (len(K_record_times), nexp, len(x0))
        Iterate at each record time.
    mus, Ls : ndarray, shape (len(K_record_times), nexp)
        First and second value returned by ``hess(x)`` at each record time;
        all zeros when ``compute_hess`` is false.

    Raises
    ------
    ValueError
        If ``alg`` is not one of the supported variants.
    """
    labels = {
        'bound': r'SPS$_{bound}$ - 1/k',
        'sqrt': r'SPS$_{bound}$ - 1/sqrt(k)',
        'max': r'SPS$_{\max}$',
        'bound_2grad': r'SPS2$_{bound}$ - 1/sqrt(k)',
        'bound_3grad': r'SPS3$_{bound}$ - 1/sqrt(k)',
    }
    # Fail before doing any work: an unknown variant would otherwise run the
    # whole loop with zero steps and only crash when building the label.
    if alg not in labels:
        raise ValueError(
            "unknown alg {!r}; expected one of {}".format(alg, sorted(labels))
        )
    separate_update_batch = alg in ('bound_2grad', 'bound_3grad')
    separate_sps_batch = alg == 'bound_3grad'

    # number of iterations
    K = K_record_times[-1]

    # init
    n_rec = len(K_record_times)
    f = np.zeros((n_rec, nexp))
    mus = np.zeros((n_rec, nexp))
    Ls = np.zeros((n_rec, nexp))
    gammas_rec = np.zeros((n_rec, nexp))
    x_rec = np.zeros((n_rec, nexp, len(x0)))

    ## optimization
    for e in range(nexp):
        # Mini-batches are drawn list by list (first, then second, then
        # third) so the sampling order is the same for every variant.
        batch = [random.sample(range(n), batch_size) for _ in range(K + 1)]
        if separate_update_batch:
            batch2 = [random.sample(range(n), batch_size) for _ in range(K + 1)]
        if separate_sps_batch:
            batch3 = [random.sample(range(n), batch_size) for _ in range(K + 1)]

        i_record = 0
        # Only the current iterate is kept; rebinding (instead of updating
        # in place) guarantees that x0 is left untouched.
        x = x0
        gammas = np.zeros((K + 2,))

        for k in range(K + 1):
            # Polyak-type ratio: mini-batch cost over squared gradient norm.
            loss = cost(x, batch[k])
            if separate_sps_batch:
                g_sps = grad(x, batch3[k])
            else:
                g_sps = grad(x, batch[k])
            sps_grad = loss / la.norm(g_sps) ** 2

            if alg == 'bound':
                cap = gamma0 / (k + 1)
            elif alg == 'max':
                cap = gamma0
            else:
                cap = gamma0 / np.sqrt(k + 1)
            gammas[k] = min([sps_grad / c, cap])

            # record
            if k == K_record_times[i_record]:
                gammas_rec[i_record, e] = gammas[k]
                f[i_record, e] = cost(x, range(n))
                x_rec[i_record, e, :] = x
                if compute_hess:
                    mus[i_record, e], Ls[i_record, e] = hess(x)
                i_record += 1

            # update
            if separate_update_batch:
                g_step = grad(x, batch2[k])
            else:
                # Same point and same mini-batch as the step-size gradient.
                g_step = g_sps
            x = x - gammas[k] * g_step

    ## name
    # Raw strings throughout: '\g' is not a valid escape sequence in a
    # normal string literal.
    name = (labels[alg] + ', $c=' + "{:.2f}".format(c)
            + r', \gamma_{0}=' + "{:.2f}".format(gamma0) + '$')

    return name, f, gammas_rec, x_rec, mus, Ls


def SPS_decr(cost,grad,hess,nexp, K_record_times, compute_hess, c_init, decr, gamma_max, x0, batch_size, n):
    """Run SGD with a decreasing stochastic-Polyak-type step size.

    With ``sps_k = cost(x_k, batch_k) / ||grad(x_k, batch_k) + 1e-4||**2``
    (the constant is added to every gradient component before the norm is
    taken), the step sizes follow the recursion

    * ``c_0 = c_init``, ``gamma_0 = min(sps_0, c_0 * gamma_max) / c_0``
    * ``gamma_k = min(sps_k, c_{k-1} * gamma_{k-1}) / c_k`` for ``k >= 1``

    where ``c_k = c_init * sqrt(k + 1)`` when ``decr == 'sqrt'`` and
    ``c_k = c_init * (k + 1)`` for any other value of ``decr``.

    Parameters
    ----------
    cost : callable
        ``cost(x, indices)``; evaluated on the mini-batch for the step size
        and on ``range(n)`` at every record time.
    grad : callable
        ``grad(x, indices)``.  It is assumed to be a deterministic function
        of its arguments: the gradient is evaluated once per iteration and
        used for both the step size and the update.
    hess, compute_hess : object
        Unused; kept so the signature stays unchanged.
    nexp : int
        Number of independent runs; each run draws its own mini-batches.
    K_record_times : sequence of int
        Iteration indices at which results are recorded.  The last entry is
        used as the final iteration index ``K``.  Entries are matched one
        after the other while ``k`` increases, so they are assumed to be
        strictly increasing.
    c_init : float
        Initial value ``c_0`` of the scaling sequence.
    decr : str
        ``'sqrt'`` selects ``c_k = c_init * sqrt(k + 1)``; anything else
        selects ``c_k = c_init * (k + 1)``.
    gamma_max : float
        Upper bound applied to the very first step size.
    x0 : array_like
        Starting point of every run; it is never modified.  ``len(x0)``
        fixes the last dimension of ``x_rec``.
    batch_size : int
        Number of indices drawn (without replacement) for each mini-batch.
    n : int
        Mini-batch indices are sampled from ``range(n)``.

    Returns
    -------
    name : str
        LaTeX-formatted label for the method.
    f : ndarray, shape (len(K_record_times), nexp)
        Cost over ``range(n)`` at each record time.
    gammas_rec : ndarray, shape (len(K_record_times), nexp)
        Step size used at each record time.
    x_rec : ndarray, shape (len(K_record_times), nexp, len(x0))
        Iterate at each record time.
    """
    # number of iterations
    K = K_record_times[-1]

    # init
    n_rec = len(K_record_times)
    f = np.zeros((n_rec, nexp))
    gammas_rec = np.zeros((n_rec, nexp))
    x_rec = np.zeros((n_rec, nexp, len(x0)))

    ## optimization
    for e in range(nexp):
        # One mini-batch per iteration, drawn up front for this run.
        batch = [random.sample(range(n), batch_size) for _ in range(K + 1)]

        i_record = 0
        # Only the current iterate is kept; rebinding (instead of updating
        # in place) guarantees that x0 is left untouched.
        x = x0
        gammas = np.zeros((K + 2,))
        c = np.zeros((K + 2,))

        for k in range(K + 1):
            loss = cost(x, batch[k])
            # Evaluated once; reused below for the update.
            g = grad(x, batch[k])
            sps_grad = loss / la.norm(g + 1e-4) ** 2

            if k == 0:
                c[0] = c_init
                gammas[0] = min([sps_grad, c[0] * gamma_max]) / c[0]
            else:
                if decr == 'sqrt':
                    c[k] = c_init * np.sqrt(k + 1)
                else:
                    c[k] = c_init * (k + 1)
                gammas[k] = min([sps_grad, c[k - 1] * gammas[k - 1]]) / c[k]

            # record
            if k == K_record_times[i_record]:
                gammas_rec[i_record, e] = gammas[k]
                f[i_record, e] = cost(x, range(n))
                x_rec[i_record, e, :] = x
                i_record += 1

            # update
            x = x - gammas[k] * g

    ## name
    # Raw string for the LaTeX part: '\g' is not a valid escape sequence in
    # a normal string literal.
    name = (r'DecSPS, $c_0=' + "{:.2f}".format(c_init)
            + r', \gamma_{b}=' + "{:.0f}".format(gamma_max) + '$')

    return name, f, gammas_rec, x_rec